import numpy as np
import torch
from .dataset import Mydata
from torch.utils.data import DataLoader
from Global import *

def get_od_data(*, batch_size=6, shuffle_train=False, data_path=None):
    """Load the taxi OD array and wrap it in train/test ``DataLoader``s.

    All parameters are keyword-only. The defaults reproduce the previously
    hard-coded behaviour, so existing ``get_od_data()`` calls are unchanged.

    Args:
        batch_size: Samples per batch for both loaders. Defaults to 6.
        shuffle_train: Whether the training loader reshuffles each epoch.
            Defaults to False. The test loader is never shuffled.
        data_path: Path of the array file passed to ``np.load``. ``None``
            (default) falls back to ``DATA_PATH``, which is not defined in
            this file -- presumably supplied by ``from Global import *``.

    Returns:
        A ``(train_loader, test_loader)`` tuple of ``DataLoader`` objects.
    """
    if data_path is None:
        data_path = DATA_PATH
    taxi_od = np.load(file=data_path)

    # NOTE(review): both datasets receive the full array. The train/test
    # selection presumably happens inside Mydata via ``is_train_data`` --
    # confirm in .dataset.
    train_data = Mydata(taxi_od)
    test_data = Mydata(taxi_od, is_train_data=False)

    train_loader = DataLoader(
        dataset=train_data,
        batch_size=batch_size,
        shuffle=shuffle_train,
    )
    # Keep evaluation order deterministic.
    test_loader = DataLoader(
        dataset=test_data,
        batch_size=batch_size,
        shuffle=False,
    )
    return train_loader, test_loader

def get_adj_matrix(path):
    """Load the array stored at *path* with ``np.load`` and return it as a float32 tensor.

    The array is wrapped with ``torch.from_numpy`` and then cast to
    ``torch.float32``. If the stored array is already float32, no copy is
    made and the tensor shares memory with the loaded array.
    """
    loaded = np.load(path)
    as_tensor = torch.from_numpy(loaded)
    return as_tensor.to(torch.float32)